#!/usr/bin/python

################################################################
### Align transcriptions with lattices.
################################################################

import sys,os,re,glob,math,glob,signal,traceback,sqlite3
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from scipy.ndimage import interpolation,measurements
from pylab import *
from optparse import OptionParser
from multiprocessing import Pool
import ocrolib
from ocrolib import number_of_processors,fstutils,die,docproc
from scipy import stats
from ocrolib import dbhelper,ocroio,ocrofst
from ocrolib import ligatures,rect_union,Record
import multiprocessing

# Exit quietly with status 1 on Ctrl-C instead of raising KeyboardInterrupt.
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))

# Command line definition.  The parsed options are stored in the
# module-level name `args`; the positional arguments are in `args.args`.
import argparse
parser = argparse.ArgumentParser(description = """
%(prog)s [-s gt] [-p] [args] *.gt.txt
""")

# character extraction
parser.add_argument("-x","--extract",help="extract characters",default=None)
parser.add_argument("-X","--noextract",help="don't actually write files",action="store_true")
parser.add_argument("-B","--beam",help="size of beam",type=int,default=100)

# where the ground truth comes from (one of these is required)
parser.add_argument("-g","--gt",help="extension for ground truth",default=None)
parser.add_argument("-p","--pagegt",help="arguments are page ground truth",action="store_true")
parser.add_argument("-l","--langmod",help="language model",default=None)

# output files
parser.add_argument("-s","--suffix",help="output suffix for writing result",default=None)
parser.add_argument("-O","--overwrite",help="overwrite outputs",action="store_true")

# thresholds used to reject poor alignments
parser.add_argument("-P","--perc",help="percentile for reporting statistics",type=float,default=90.0)
parser.add_argument("-M","--maxperc",help="maximum cost at percentile",type=float,default=5.0)
parser.add_argument("-A","--maxavg",help="maximum average cost",type=float,default=5.0)

# costs used when building the alignment transducer
parser.add_argument("-c","--mmcost",help="mismatch cost",type=float,default=2.0)
parser.add_argument("-C","--edcost",help="edit cost",type=float,default=5.0)

parser.add_argument("-Q","--parallel",type=int,default=multiprocessing.cpu_count(),help="number of parallel processes to use")

# debugging and display
parser.add_argument("--debug_line",action="store_true")
parser.add_argument("--debug_rawalign",action="store_true")
parser.add_argument("--debug_aligned",action="store_true")
parser.add_argument("--debug_select",action="store_true")
parser.add_argument("-D","--Display",help="display",action="store_true")
# type=int so that a value given on the command line is not left as a string
parser.add_argument('--dgrid',type=int,default=9,help="grid size for display")

parser.add_argument("args",default=[],nargs='*',help="input lines)")
args = parser.parse_args()

# Refuse to append to an existing extraction database.  The check only
# applies when -x was given (args.extract defaults to None, which
# os.path.exists cannot handle).
if args.extract is not None and os.path.exists(args.extract):
    parser.error("%s: already exists; please remove"%args.extract)

class DbWriter:
    """Context manager that inserts character samples into an sqlite database.

    The database file is opened with a 600 second busy timeout.  If the
    file did not exist beforehand, the "chars" table is set up through
    `dbhelper.charcolumns`.  Leaving the `with` block commits and closes
    the connection.
    """
    def __init__(self,fname):
        # The existence test must come before connect(), because
        # connecting creates the file.
        is_new = not os.path.exists(fname)
        db = sqlite3.connect(fname,timeout=600.0)
        if is_new:
            dbhelper.charcolumns(db,"chars")
            db.commit()
        self.db = db
    def __enter__(self):
        return self
    def __exit__(self,*args):
        # Commit and close unconditionally, also when the block raised.
        self.db.commit()
        self.db.close()
        del self.db
    def insert(self,image,cls,cost=0.0,count=1,file=None,lgeo=None,rel=None,bbox=None):
        """Insert one character sample into the "chars" table.

        image -- character image; stored via `dbhelper.image2blob`
        cls   -- class label of the sample
        cost  -- cost of the sample (stored as a float)
        count -- count stored with the sample; this argument used to be
                 ignored in favour of a hard-coded 1, so the default is 1
                 to keep rows written by callers that omit it unchanged
        file  -- name of the file the sample comes from
        lgeo  -- optional 3-tuple, stored as "%g %g %g"
        rel   -- optional 3-tuple, stored as "%g %g %g"
        bbox  -- optional 4-tuple, stored as "%d %d %d %d"
        """
        dbhelper.dbinsert(self.db,"chars",
                          image=dbhelper.image2blob(image),
                          cost=float(cost),
                          cls=cls,
                          count=int(count),
                          file=file,
                          lgeo="%g %g %g"%lgeo if lgeo else None,
                          rel="%g %g %g"%rel if rel else None,
                          bbox="%d %d %d %d"%bbox if bbox else None)

import codecs
import openfst
from ocrolib.fstutils import explode_transcription,epsilon,space,sigma,add_between,optimize_openfst,openfst2ocrofst

def alignment_fst(line):
    """Build the transducer used to align the transcription `line`.

    The transducer is a chain of len(line)+1 states, one step per
    character of `line`, with extra arcs for insertions, mismatches and
    deletions.  Edit and mismatch costs come from the command line
    options (`args.edcost`, `args.mmcost`).  Returns an `ocrofst.OcroFST`
    whose only final state is the end of the chain.
    """
    fst = ocrofst.OcroFST()
    initial = fst.AddState()
    fst.SetStart(initial)
    # chain[k] is the state reached after consuming k characters
    chain = [initial] + [fst.AddState() for _ in line]

    tilde = ord("~")
    edit_cost = args.edcost
    mismatch_cost = args.mmcost

    for here,there,ch in zip(chain,chain[1:],line):
        code = ord(ch)

        # spaces get special treatment: they can be inserted for free
        # after any position ...
        fst.AddArc(there,space,space,0.0,there)
        # ... while any other inserted symbol is output as "~" at edit cost
        fst.AddArc(here,sigma,tilde,edit_cost,here)

        if ch==" ":
            # a space in the transcription is either matched ...
            fst.AddArc(here,space,space,0.0,there)
            # ... or skipped, both at no cost
            fst.AddArc(here,epsilon,space,0.0,there)
            continue

        # exact match of the character
        fst.AddArc(here,code,code,0.0,there)
        # a reject ("~") on the input side is treated as an insertion
        fst.AddArc(here,tilde,tilde,edit_cost,here)
        # input symbol differs from the transcription
        fst.AddArc(here,sigma,code,mismatch_cost,there)
        # character missing from the lattice
        fst.AddArc(here,epsilon,code,edit_cost,there)

    # trailing input after the end of the transcription is dropped at edit cost
    final = chain[-1]
    fst.AddArc(final,sigma,epsilon,edit_cost,final)
    fst.SetFinal(final,0.0)
    return fst

# Module-level debug flag; nothing in this section of the file reads it.
debug = 0

def align1(job):
    fname,gtfile = job
    # read the ground truth data and construct an FST
    if not os.path.exists(gtfile):
        print gtfile,": NOT FOUND"
        return
    with open(gtfile) as stream:
        gttext = stream.read()[:-1]


    fst_file = ocrolib.fvariant(fname,"fst")
    if not os.path.exists(fst_file):
        print fst_file,": NOT FOUND"
        return
    fst = ocrofst.OcroFST()
    fst.load(fst_file)

    bestcost = 1e38
    bestline = None
    best = None
    for i,line in enumerate(gttext.split("\n")):
        line = line.strip()
        line = re.sub(r'[~_ \t]+',' ',line)
        if len(line)<=1: continue
        gtfst = alignment_fst(line)
        # actually perform the alignment
        result = ocrofst.beam_search(fst,gtfst,100)
        if args.debug_line:
            print "line",i,len(result[0]),sum(result[4])
        if len(result[0])<=1: continue
        avg = sum(result[4])/len(result[0]+10.0)
        if avg>=bestcost: continue
        bestcost = avg
        bestline = line
        best = result

    v1,v2,ins,outs,costs = best

    if args.debug_rawalign:
        for i in range(len(v1)):
            print "raw-align %3d [%3d %3d] (%3d %3d) %6.2f"%(i,v1[i],v2[i],ins[i]>>16,ins[i]&0xffff,costs[i]),unichr(outs[i])

    sresult = []
    scosts = []
    segs = []

    # Here is the general idea:
    # - low-cost correspondences are output directly
    # - stretches of high-cost correspondences and epsilons are output together

    n = len(ins)
    i = 1
    while i<n:
        j = i+1
        if outs[i]==ord(" "):
            sresult.append(" ")
            scosts.append(costs[i])
            segs.append((0,0))
            i = j
            continue
        if 0 and outs[i]==ord("~"):
            # may not need special treatment for "~"
            prev = [x for x in ins[:i-1] if x!=0]
            succ = [x for x in ins[i+1:] if x!=0]
            if len(prev)>0 and ins[i]==prev[-1] or len(succ)>0 and ins[i]==succ[0]:
                outs[i] = 0
            else:
                sresult.append("~")
                scosts.append(costs[i])
                segs.append((0,0))
                i = j
                continue
        if costs[i]>2.0 or ins[i]==0:
            seg = ins[i]
            while j<n:
                # only output combined segments as characters
                if seg!=0 and ins[j]!=seg: break
                # break at spaces
                if outs[j]==ord(" "): break
                # keep track of current segment (skip may precede or follow segment)
                if seg==0 : seg = ins[j]
                j += 1
        cls = "".join([unichr(x) for x in outs[i:j] if x!=0])
        sresult.append(cls)
        scosts.append(sum(costs[i:j]))
        start = min([x>>16 for x in ins[i:j] if x!=0]+[9999])
        end = max([x&0xffff for x in ins[i:j] if x!=0]+[0])
        segs.append((start,end))
        i = j

    if args.debug_aligned:
        for i,row in enumerate(zip(sresult,scosts,segs)):
            print "aligned",i,row

    ml = max([len(x) for x in sresult])
    lig = sum([len(x)>1 for x in sresult])
    bad = sum([x=="~" for x in sresult])
    perc = stats.scoreatpercentile(costs,args.perc)
    avg = mean(costs)
    skip = (perc > args.maxperc or avg>args.maxavg or ml>3 or lig*1.0/len(sresult)>0.2 or bad*1.0/(3.0+len(sresult))>0.1)
    aligned = fstutils.implode_transcription(sresult,maxlig=100)
    print "%c%s %6.2f %6.2f:"%("*" if skip else " ",fname,perc,avg),aligned
    if skip: return

    # read the raw segmentation
    rseg_file = ocrolib.fvariant(fname,"rseg")
    if not os.path.exists(rseg_file):
        print rseg_file,": NOT FOUND"
        return
    rseg = ocroio.read_line_segmentation(rseg_file)
    rseg_boxes = docproc.seg_boxes(rseg)

    # Now run through the segments and create a table that maps rseg
    # labels to the corresponding output element.

    assert len(sresult)==len(segs)
    assert len(scosts)==len(segs)

    bboxes = []

    rmap = zeros(amax(rseg)+1,'i')
    for i in range(1,len(segs)):
        start,end = segs[i]
        if start==0 or end==0: continue
        rmap[start:end+1] = i
        bboxes.append(rect_union(rseg_boxes[start:end+1]))
    assert rmap[0]==0

    cseg = zeros(rseg.shape,'i')
    for i in range(cseg.shape[0]):
        for j in range(cseg.shape[1]):
            cseg[i,j] = rmap[rseg[i,j]]

    assert len(segs)==len(sresult) 
    assert len(segs)==len(scosts)

    assert amin(cseg)==0,"amin(cseg)!=0 (%d,%d)"%(amin(cseg),amax(cseg))

    # first work on the list output; here, one list element should
    # correspond to each cost
    result = sresult
    assert len(result)==len(scosts),\
        "output length %d differs from cost length %d"%(len(result),len(costs))
    assert amax(cseg)<len(result),\
        "amax(cseg) %d not consistent with output length %d"%(amax(cseg),len(output_l))

    # if there are spaces at the end, trim them (since they will never have a corresponding cseg)

    while len(result)>0 and result[-1]==" ":
        result = result[:-1]
        costs = costs[:-1]

    perc = stats.scoreatpercentile(costs,50)
    avg = mean(costs)
    skip = (perc > 10.0 or avg>10.0)

    if args.suffix is not None:
        cseg_file = ocrolib.fvariant(fname,"cseg",args.suffix)
        if not args.overwrite:
            if os.path.exists(cseg_file): die("%s: already exists",cseg_file)
        ocrolib.write_line_segmentation(cseg_file,cseg)
        # aligned text line
        ocrolib.write_text(ocrolib.fvariant(fname,"aligned",args.suffix),aligned)
        # per character costs
        with ocrolib.fopen(fname,"costs",args.suffix,mode="w") as stream:
            for i in range(len(costs)):
                stream.write("%d %g\n"%(i,costs[i]))
        # true ground truth (best-matching line in the original transcription)
        ocrolib.write_text(ocrolib.fvariant(fname,"txt",args.suffix),bestline)

    if args.extract is not None:
        iraw = 0
        if args.Display:
            ion(); clf()
        line = ocrolib.read_image_gray(ocrolib.fvariant(fname,"png"))
        line = amax(line)-line
        lgeo = docproc.seg_geometry(rseg)
        grouper = ocrolib.Grouper()
        grouper.setSegmentation(rseg)
        with DbWriter(args.extract) as db:
            for i in range(grouper.length()):
                raw,mask = grouper.extractWithMask(line,i,dtype='B')
                start = grouper.start(i)
                end = grouper.end(i)
                bbox = grouper.boundingBox(i)
                y0,x0,y1,x1 = bbox
                rel = docproc.rel_char_geom((y0,y1,x0,x1),lgeo)
                ry,rw,rh = rel
                assert rw>0 and rh>0
                if (start,end) in segs:
                    index = segs.index((start,end))
                    cls = sresult[index]
                    cost = costs[index]
                else:
                    cls = "~"
                    cost = 0.0
                if cls!="~":
                    if args.Display:
                        iraw += 1
                        subplot(args.dgrid,args.dgrid,iraw)
                        gray(); imshow(raw)
                        ax = gca()
                        ax.text(0.1,0.1,"%s"%cls,transform=ax.transAxes,color='red')
                        print "output %3d %3d cls %-5s cost %6.2f    "%(start,end,cls,cost),
                        print "y %6.2f w %6.2f h %6.2f"%(rel[0],rel[1],rel[2])
                    db.insert(image=raw,
                              cost=float(cost),
                              cls=cls,
                              count=1,
                              file=fname,
                              lgeo=lgeo,
                              rel=rel,
                              bbox=bbox)
    if args.Display: ginput(1,10000)

def safe_align1(job):
    """Run `align1(job)` without letting any exception escape.

    Used as the worker function for the process pool: a failure is
    reported as a traceback on stderr and the job is otherwise ignored.
    Always returns None.
    """
    try:
        align1(job)
    except:
        # deliberately catches everything so one bad job cannot take
        # down the worker
        sys.stderr.write(traceback.format_exc())

jobs = []

if args.pagegt:
    if len(args.args)==1 and os.path.isdir(args.args[0]):
        files = glob.glob(args.args[0]+"/*.gt.txt")
    else:
        files = args.args
    for arg in files:
        if not os.path.exists(arg):
            print "# %s: not found"%arg
            continue
        base,_ = ocrolib.allsplitext(arg)
        if not os.path.exists(base) or not os.path.isdir(base):
            print "["+base+"?]"
            sys.stdout.flush()
            continue
        lines = glob.glob(base+"/??????.png")
        for line in lines:
            jobs.append((line,arg))
    print
elif args.langmod:
    for arg in args.args:
        jobs.append((arg,args.langmod))
elif args.gt is not None:
    if len(args.args)==1 and (os.path.isdir(args.args[0]) or os.path.islink(args.args[0])):
        files = glob.glob(args[0]+"/????/??????.png")
    else:
        files = args.args
    for arg in files:
        path,ext = ocrolib.allsplitext(arg)
        p = path+args.gt
        if not os.path.exists(arg):
            print arg,"not found"
            continue
        if not os.path.exists(p): 
            print p,"not found"
            continue
        jobs.append((arg,p))
else:
    raise Exception("you need to specify what kind of groundtruth you want to align with (-p, -l, -g)")
        

jobs = []

if args.pagegt:
    if len(args.args)==1 and os.path.isdir(args.args[0]):
        files = glob.glob(args[0]+"/*.gt.txt")
    else:
        files = args.args
    for arg in files:
        if not os.path.exists(arg):
            print "# %s: not found"%arg
            continue
        base,_ = ocrolib.allsplitext(arg)
        if not os.path.exists(base) or not os.path.isdir(base):
            print "["+base+"?]"
            sys.stdout.flush()
            continue
        lines = glob.glob(base+"/??????.png")
        for line in lines:
            jobs.append((line,arg))
    print
elif args.langmod:
    for arg in args:
        jobs.append((arg,args.langmod))
elif args.gt is not None:
    if len(args.args)==1 and (os.path.isdir(args.args[0]) or os.path.islink(args.args[0])):
        args = glob.glob(args[0]+"/????/??????.png")
    for arg in args.args:
        path,ext = ocrolib.allsplitext(arg)
        p = path+args.gt
        if not os.path.exists(arg):
            print arg,"not found"
            continue
        if not os.path.exists(p): 
            print p,"not found"
            continue
        jobs.append((arg,p))
else:
    raise Exception("you need to specify what kind of groundtruth you want to align with (-p, -l, -g)")

# Run the jobs.  With two or more processes the work is spread over a
# pool and failures of individual jobs are only reported (safe_align1);
# otherwise the jobs run one after another in this process and an
# exception stops the run.
if args.parallel>=2:
    pool = Pool(processes=args.parallel)
    result = pool.map(safe_align1,jobs)
else:
    for arg in jobs:
        align1(arg)
